# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end tests that check model correctness."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import tempfile
import unittest

import numpy as np
import tensorflow as tf

# pylint: disable=g-bad-import-order
from tfltransfer import bases
from tfltransfer import optimizers
from tfltransfer import heads
from tfltransfer import tflite_transfer_converter
# pylint: enable=g-bad-import-order

# Side length (pixels) images are resized to; also the MobileNetV2 input size.
IMAGE_SIZE = 224
# Samples per batch for image loading, head training and model conversion.
BATCH_SIZE = 128
# Number of classes; label arrays are allocated with this many columns, so it
# must match the number of class subdirectories in the dataset.
NUM_CLASSES = 5
# Fraction of the dataset held out as the validation subset.
VALIDATION_SPLIT = 0.2
# Learning rate used by the SGD optimizer tests.
LEARNING_RATE = 0.001
# Shape of one image's bottleneck (base model output), excluding the batch
# dimension.
BOTTLENECK_SHAPE = (7, 7, 1280)

# Archive of the flower photos dataset, extracted to a 'flower_photos' dir.
DATASET_URL = 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'


class TransferModel(object):
  """Test consumer of models generated by the converter.

  Drives the TFLite models produced by TFLiteTransferConverter (initialize,
  bottleneck, train_head, inference and optimizer) from Python: it computes
  bottlenecks for a dataset, trains the head with the converted optimizer and
  measures validation accuracy with the inference model.
  """

  def __init__(self, dataset_dir, base_model, head_model, optimizer):
    """Creates a wrapper for a set of models and a data set.

    Args:
      dataset_dir: Image directory consumed by
        ImageDataGenerator.flow_from_directory.
      base_model: Base model handed to the converter.
      head_model: Head model handed to the converter.
      optimizer: Optimizer handed to the converter.
    """
    self.dataset_dir = dataset_dir

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1. / 255, validation_split=VALIDATION_SPLIT)
    self.train_img_generator = datagen.flow_from_directory(
        self.dataset_dir,
        target_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=BATCH_SIZE,
        subset='training')
    self.val_img_generator = datagen.flow_from_directory(
        self.dataset_dir,
        target_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=BATCH_SIZE,
        subset='validation')

    converter = tflite_transfer_converter.TFLiteTransferConverter(
        NUM_CLASSES, base_model, head_model, optimizer, BATCH_SIZE)
    # Serialized TFLite models keyed by role; each is loaded via model_content.
    models = converter._convert()  # pylint: disable=protected-access
    self.initialize_model = models['initialize']
    self.bottleneck_model = models['bottleneck']
    self.train_head_model = models['train_head']
    self.inference_model = models['inference']
    self.optimizer_model = models['optimizer']
    # Created lazily by _get_optimizer_interpreter() and reused afterwards.
    self._optimizer_interpreter = None
    self.variables = self._generate_initial_variables()

    # Optimizer state tensors start out zero-initialized.
    optim_state_shapes = self._optimizer_state_shapes()
    self.optim_state = [
        np.zeros(shape, dtype=np.float32) for shape in optim_state_shapes
    ]

  def _generate_initial_variables(self):
    """Runs the initialize model to produce the initial model variables.

    Returns:
      A list of arrays, one per output of the initialize model.
    """
    interpreter = tf.lite.Interpreter(model_content=self.initialize_model)
    zero_in = interpreter.get_input_details()[0]
    variable_outs = interpreter.get_output_details()
    interpreter.allocate_tensors()
    # The initialize model is fed a single scalar zero.
    interpreter.set_tensor(zero_in['index'], np.float32(0.))
    interpreter.invoke()
    return [interpreter.get_tensor(var['index']) for var in variable_outs]

  def _get_optimizer_interpreter(self):
    """Returns the optimizer interpreter, creating and allocating it once.

    The optimizer model has fixed input shapes and every input is overwritten
    on each call, so a single interpreter can be reused for all batches.
    """
    if getattr(self, '_optimizer_interpreter', None) is None:
      interpreter = tf.lite.Interpreter(model_content=self.optimizer_model)
      interpreter.allocate_tensors()
      self._optimizer_interpreter = interpreter
    return self._optimizer_interpreter

  def _optimizer_state_shapes(self):
    """Reads the shapes of the optimizer parameters (mutable state)."""
    interpreter = self._get_optimizer_interpreter()
    num_variables = len(self.variables)
    # Optimizer inputs are ordered [variables..., gradients..., state...].
    optim_state_inputs = interpreter.get_input_details()[num_variables * 2:]
    return [input_['shape'] for input_ in optim_state_inputs]

  def prepare_bottlenecks(self):
    """Passes all images through the base model and saves the bottlenecks.

    This method has to be called before any training or inference.
    """
    self.train_bottlenecks, self.train_labels = (
        self._collect_and_generate_bottlenecks(self.train_img_generator))
    self.val_bottlenecks, self.val_labels = (
        self._collect_and_generate_bottlenecks(self.val_img_generator))

  def _collect_and_generate_bottlenecks(self, image_gen):
    """Consumes a generator and converts all images to bottlenecks.

    Args:
      image_gen: A Keras data generator for images to process

    Returns:
      Two NumPy arrays: (bottlenecks, labels), holding only the rows that
      were actually produced by the generator.
    """
    collected_bottlenecks = np.zeros(
        (image_gen.samples,) + BOTTLENECK_SHAPE, dtype=np.float32)
    collected_labels = np.zeros((image_gen.samples, NUM_CLASSES),
                                dtype=np.float32)

    next_idx = 0
    for bottlenecks, truth in self._generate_bottlenecks(
        make_finite(image_gen)):
      batch_size = bottlenecks.shape[0]
      collected_bottlenecks[next_idx:next_idx + batch_size] = bottlenecks
      collected_labels[next_idx:next_idx + batch_size] = truth
      next_idx += batch_size

    # Drop unfilled rows: an all-zero label would argmax to class 0 and skew
    # training and accuracy if the generator ever produced fewer samples.
    return collected_bottlenecks[:next_idx], collected_labels[:next_idx]

  def _generate_bottlenecks(self, image_gen):
    """Generator adapter that passes images through the bottleneck model.

    Args:
      image_gen: A generator that returns images to be processed. Images are
        paired with ground truth labels.

    Yields:
      Bottlenecks from input images, paired with ground truth labels.
    """
    interpreter = tf.lite.Interpreter(model_content=self.bottleneck_model)
    [x_in] = interpreter.get_input_details()
    [bottleneck_out] = interpreter.get_output_details()

    allocated_batch_size = None
    for (x, y) in image_gen:
      batch_size = x.shape[0]
      # Resizing requires re-allocation; only do it when the size changes.
      if batch_size != allocated_batch_size:
        interpreter.resize_tensor_input(
            x_in['index'], (batch_size, IMAGE_SIZE, IMAGE_SIZE, 3))
        interpreter.allocate_tensors()
        allocated_batch_size = batch_size
      interpreter.set_tensor(x_in['index'], x)
      interpreter.invoke()
      bottleneck = interpreter.get_tensor(bottleneck_out['index'])
      yield bottleneck, y

  def train_head(self, num_epochs):
    """Trains the head model for a given number of epochs.

    The optimizer passed to the constructor (converted into the optimizer
    model) is used to apply the gradients.

    Args:
      num_epochs: how many epochs should be trained

    Returns:
      A list of train_loss values after every epoch trained.

    Raises:
      RuntimeError: when prepare_bottlenecks() has not been called.
    """
    if not hasattr(self, 'train_bottlenecks'):
      raise RuntimeError('prepare_bottlenecks has not been called')
    results = []
    for _ in range(num_epochs):
      loss = self._train_one_epoch(
          self._generate_batches(self.train_bottlenecks, self.train_labels))
      results.append(loss)
    return results

  def _generate_batches(self, x, y):
    """Creates a generator that iterates over the data in batches."""
    num_total = x.shape[0]
    for begin in range(0, num_total, BATCH_SIZE):
      end = min(begin + BATCH_SIZE, num_total)
      yield x[begin:end], y[begin:end]

  def _train_one_epoch(self, train_gen):
    """Performs one training epoch.

    Args:
      train_gen: Iterable of (bottlenecks, labels) batches.

    Returns:
      The mean loss over all real (non-padding) samples of the epoch.

    Raises:
      ValueError: if train_gen produced no samples.
    """
    interpreter = tf.lite.Interpreter(model_content=self.train_head_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Inputs: [bottlenecks, labels, variables...]; outputs: [loss, grads...].
    x_in, y_in = input_details[:2]
    variable_ins = input_details[2:]
    loss_out = output_details[0]
    gradient_outs = output_details[1:]

    epoch_loss = 0.
    num_processed = 0
    for bottlenecks, truth in train_gen:
      batch_size = bottlenecks.shape[0]
      # train_head is never resized here and is always fed BATCH_SIZE rows,
      # so a short final batch is padded by tiling its own samples.
      if batch_size < BATCH_SIZE:
        bottlenecks = pad_batch(bottlenecks, BATCH_SIZE)
        truth = pad_batch(truth, BATCH_SIZE)

      interpreter.set_tensor(x_in['index'], bottlenecks)
      interpreter.set_tensor(y_in['index'], truth)
      for variable_in, variable_value in zip(variable_ins, self.variables):
        interpreter.set_tensor(variable_in['index'], variable_value)
      interpreter.invoke()

      loss = interpreter.get_tensor(loss_out['index'])
      gradients = [
          interpreter.get_tensor(gradient_out['index'])
          for gradient_out in gradient_outs
      ]

      self._apply_gradients(gradients)
      # Weight by the real batch size so padding does not bias the average.
      epoch_loss += loss * batch_size
      num_processed += batch_size

    if num_processed == 0:
      raise ValueError('No training data was provided')
    return epoch_loss / num_processed

  def _apply_gradients(self, gradients):
    """Applies the optimizer to the model parameters.

    Updates self.variables and self.optim_state in place with the outputs of
    the optimizer model.

    Args:
      gradients: One gradient array per entry of self.variables.
    """
    interpreter = self._get_optimizer_interpreter()
    num_variables = len(self.variables)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Inputs: [variables..., gradients..., state...];
    # outputs: [new variables..., new state...].
    variable_ins = input_details[:num_variables]
    gradient_ins = input_details[num_variables:num_variables * 2]
    state_ins = input_details[num_variables * 2:]
    variable_outs = output_details[:num_variables]
    state_outs = output_details[num_variables:]

    for variable, gradient, variable_in, gradient_in in zip(
        self.variables, gradients, variable_ins, gradient_ins):
      interpreter.set_tensor(variable_in['index'], variable)
      interpreter.set_tensor(gradient_in['index'], gradient)

    for optim_state_elem, state_in in zip(self.optim_state, state_ins):
      interpreter.set_tensor(state_in['index'], optim_state_elem)

    interpreter.invoke()
    # get_tensor returns copies, so reusing the interpreter is safe.
    self.variables = [
        interpreter.get_tensor(variable_out['index'])
        for variable_out in variable_outs
    ]
    self.optim_state = [
        interpreter.get_tensor(state_out['index']) for state_out in state_outs
    ]

  def measure_inference_accuracy(self):
    """Runs the inference model and measures accuracy on the validation set.

    Returns:
      Fraction of validation samples whose predicted class matches the label.

    Raises:
      RuntimeError: when prepare_bottlenecks() has not been called.
      ValueError: if the validation set is empty.
    """
    if not hasattr(self, 'val_bottlenecks'):
      raise RuntimeError('prepare_bottlenecks has not been called')
    interpreter = tf.lite.Interpreter(model_content=self.inference_model)
    input_details = interpreter.get_input_details()
    bottleneck_in = input_details[0]
    variable_ins = input_details[1:]
    [y_out] = interpreter.get_output_details()

    num_correct = 0
    num_processed = 0
    allocated_batch_size = None
    for bottleneck, truth in self._generate_batches(self.val_bottlenecks,
                                                    self.val_labels):
      batch_size = bottleneck.shape[0]
      # Resizing requires re-allocation; only do it when the size changes.
      if batch_size != allocated_batch_size:
        interpreter.resize_tensor_input(bottleneck_in['index'],
                                        (batch_size,) + BOTTLENECK_SHAPE)
        interpreter.allocate_tensors()
        allocated_batch_size = batch_size

      interpreter.set_tensor(bottleneck_in['index'], bottleneck)
      for variable_in, variable_value in zip(variable_ins, self.variables):
        interpreter.set_tensor(variable_in['index'], variable_value)
      interpreter.invoke()

      preds = interpreter.get_tensor(y_out['index'])
      num_correct += int(
          np.sum(np.argmax(preds, axis=1) == np.argmax(truth, axis=1)))
      num_processed += batch_size

    if num_processed == 0:
      raise ValueError('No validation data was provided')
    return num_correct / num_processed


def make_finite(data_gen):
  """An adapter for Keras data generators that makes them finite.

  The default behavior in Keras is to keep looping infinitely through
  the data. This adapter stops once `data_gen.samples` samples have been
  yielded, truncating the last batch if necessary. It never pulls a batch
  beyond that point. With Keras directory iterators, pulling an extra batch
  would load images only to discard them.

  Args:
    data_gen: An infinite Keras data generator. It must expose a `samples`
      attribute and yield tuples of arrays sharing their first dimension.

  Yields:
    Same values as the parameter generator, limited to `data_gen.samples`
    samples in total.
  """
  num_samples = data_gen.samples
  if num_samples <= 0:
    return
  num_processed = 0
  for batch in data_gen:
    batch_size = min(batch[0].shape[0], num_samples - num_processed)
    if batch_size == 0:
      # An empty batch makes no progress; stop rather than loop forever.
      return
    yield tuple(x[:batch_size] for x in batch)
    num_processed += batch_size
    if num_processed >= num_samples:
      return


# TODO(b/135138207) investigate if we can get rid of this.
def pad_batch(batch, batch_size):
  """Resize batch to a given size, tiling present samples over missing.

  Example:
    Suppose batch_size is 5, batch is [1, 2].
    Then the return value is [1, 2, 1, 2, 1].

  Args:
    batch: An ndarray with first dimension size <= batch_size. A larger batch
      is truncated to its first `batch_size` samples.
    batch_size: Desired size for first dimension.

  Returns:
    A new ndarray with the dtype and trailing dimensions of `batch`, whose
    first dimension has the desired size.

  Raises:
    ValueError: if `batch_size` is negative, or if `batch` is empty while
      `batch_size` is positive (there are no samples to tile).
  """
  if batch_size < 0:
    raise ValueError('batch_size must be non-negative, got {}'.format(
        batch_size))
  num_present = batch.shape[0]
  if num_present == 0:
    if batch_size > 0:
      raise ValueError(
          'Cannot pad an empty batch to size {}'.format(batch_size))
    return np.zeros((0,) + batch.shape[1:], dtype=batch.dtype)
  # Row i of the result is row (i mod num_present) of the input; integer
  # array indexing returns a copy with the same dtype.
  return batch[np.arange(batch_size) % num_present]


class ModelCorrectnessTest(unittest.TestCase):
  """End-to-end accuracy checks for base/head/optimizer combinations.

  Each test converts a transfer-learning pipeline, trains its head on the
  flower photos dataset and asserts a minimum validation accuracy.
  """

  @classmethod
  def setUpClass(cls):
    super(ModelCorrectnessTest, cls).setUpClass()
    archive_path = tf.keras.utils.get_file(
        origin=DATASET_URL, fname='flower_photos.tgz', extract=True)
    cls.dataset_dir = os.path.join(
        os.path.dirname(archive_path), 'flower_photos')

    # Export an ImageNet-pretrained MobileNetV2 without its classifier top as
    # a SavedModel, for the SavedModelBase tests.
    mobilenet_dir = tempfile.mkdtemp('tflite-transfer-test')
    try:
      mobilenet_keras = tf.keras.applications.MobileNetV2(
          input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
          include_top=False,
          weights='imagenet')
      tf.keras.experimental.export_saved_model(mobilenet_keras, mobilenet_dir)
    except Exception:
      # tearDownClass does not run if setUpClass fails; avoid leaking the dir.
      shutil.rmtree(mobilenet_dir, ignore_errors=True)
      raise
    cls.mobilenet_dir = mobilenet_dir

  @classmethod
  def tearDownClass(cls):
    # Remove the temporary SavedModel directory created in setUpClass.
    shutil.rmtree(cls.mobilenet_dir, ignore_errors=True)
    super(ModelCorrectnessTest, cls).tearDownClass()

  def setUp(self):
    super(ModelCorrectnessTest, self).setUp()
    self.mobilenet_dir = ModelCorrectnessTest.mobilenet_dir
    self.dataset_dir = ModelCorrectnessTest.dataset_dir

  def test_mobilenet_v2_saved_model_and_softmax_classifier(self):
    self._check_pipeline(bases.SavedModelBase(self.mobilenet_dir))

  def test_mobilenet_v2_saved_model_quantized_and_softmax_classifier(self):
    self._check_pipeline(
        bases.SavedModelBase(self.mobilenet_dir, quantize=True))

  def test_mobilenet_v2_base_and_softmax_classifier(self):
    self._check_pipeline(bases.MobileNetV2Base())

  def test_mobilenet_v2_base_and_softmax_classifier_l2(self):
    base_model = bases.MobileNetV2Base()
    head_model = heads.SoftmaxClassifierHead(
        BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES, l2_reg=0.1)
    self._check_pipeline(base_model, head_model=head_model)

  def test_mobilenet_v2_base_quantized_and_softmax_classifier(self):
    self._check_pipeline(bases.MobileNetV2Base(quantize=True))

  def test_mobilenet_v2_base_and_softmax_classifier_adam(self):
    base_model = bases.MobileNetV2Base()
    head_model = self._default_head()
    self._check_pipeline(
        base_model, head_model=head_model, optimizer=optimizers.Adam())

  def _default_head(self):
    """Returns the softmax classifier head shared by most tests."""
    return heads.SoftmaxClassifierHead(BATCH_SIZE, BOTTLENECK_SHAPE,
                                       NUM_CLASSES)

  def _check_pipeline(self,
                      base_model,
                      head_model=None,
                      optimizer=None,
                      target_accuracy=0.80):
    """Builds a TransferModel from the given parts and checks its accuracy.

    Args:
      base_model: Base model to convert.
      head_model: Head model; defaults to _default_head().
      optimizer: Optimizer; defaults to SGD with LEARNING_RATE.
      target_accuracy: Validation accuracy the model must exceed.
    """
    if head_model is None:
      head_model = self._default_head()
    if optimizer is None:
      optimizer = optimizers.SGD(LEARNING_RATE)
    model = TransferModel(self.dataset_dir, base_model, head_model, optimizer)
    self.assertModelAchievesAccuracy(model, target_accuracy)

  def assertModelAchievesAccuracy(self, model, target_accuracy, num_epochs=30):
    """Trains `model` and asserts its validation accuracy > target_accuracy."""
    model.prepare_bottlenecks()
    print('Bottlenecks prepared')
    history = model.train_head(num_epochs)
    print('Training completed, history = {}'.format(history))
    accuracy = model.measure_inference_accuracy()
    print('Final accuracy = {:.2f}'.format(accuracy))
    self.assertGreater(accuracy, target_accuracy)


if __name__ == '__main__':
  # Run every test case in this module when executed as a script.
  unittest.main()
